Understanding “Deep Double Descent”
If you’re not familiar with the double descent phenomenon, I think you should be. I consider double descent to be one of the most interesting and surprising recent results in analyzing and understanding modern machine learning. Today, Preetum et al. released a new paper, “Deep Double Descent,” which I think is a big further advancement in our understanding of this phenomenon. I’d highly recommend at least reading the summary of the paper on the OpenAI blog. However, I will also try to summarize the paper here, as well as give a history of the literature on double descent and some of my personal thoughts.
Prior work
The double descent phenomenon was first discovered by Mikhail Belkin et al., who were confused by the phenomenon wherein modern ML practitioners would claim that “bigger models are always better” despite standard statistical machine learning theory predicting that bigger models should be more prone to overfitting. Belkin et al. discovered that the standard bias-variance tradeoff picture actually breaks down once you hit approximately zero training error—what Belkin et al. call the “interpolation threshold.” Before the interpolation threshold, the bias-variance tradeoff holds and increasing model complexity leads to overfitting, increasing test error. After the interpolation threshold, however, they found that test error actually starts to go down as you keep increasing model complexity! Belkin et al. demonstrated this phenomenon in simple ML methods such as decision trees as well as simple neural networks trained on MNIST. Here’s the diagram that Belkin et al. use in their paper to describe this phenomenon:
Belkin et al. describe their hypothesis for what’s happening as follows:
All of the learned predictors to the right of the interpolation threshold fit the training data perfectly and have zero empirical risk. So why should some—in particular, those from richer functions classes—have lower test risk than others? The answer is that the capacity of the function class does not necessarily reflect how well the predictor matches the inductive bias appropriate for the problem at hand. [The inductive bias] is a form of Occam’s razor: the simplest explanation compatible with the observations should be preferred. By considering larger function classes, which contain more candidate predictors compatible with the data, we are able to find interpolating functions that [are] “simpler”. Thus increasing function class capacity improves performance of classifiers.
I think that what this is saying is pretty magical: in the case of neural nets, it’s saying that SGD just so happens to have the right inductive biases, such that letting SGD pick its preferred model out of a large class of models with the same training performance yields significantly better test performance. If you’re right on the interpolation threshold, you’re effectively “forcing” SGD to choose from a very small set of models with perfect training accuracy (maybe only one realistic option), thus ignoring SGD’s inductive biases completely. If you’re past the interpolation threshold, by contrast, you’re letting SGD choose which of many models with perfect training accuracy it prefers, thus allowing SGD’s inductive bias to shine through.
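To make the “many interpolating models, and the optimizer picks one” idea concrete, here is a minimal sketch in a setting where the implicit bias is actually known: overparameterized linear least squares. This is my own toy stand-in, not an experiment from Belkin et al., and all names and sizes (`n`, `d`, `w_other`, the step count) are assumptions chosen for illustration. With more parameters than data points there are infinitely many weight vectors with zero training error, and plain gradient descent started from zero converges to the one with the smallest norm. What the analogous bias is for SGD on deep networks is exactly the open question.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 20, 100                      # 20 training points, 100 parameters (assumed sizes)
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

# The minimum-norm interpolating solution, via the pseudoinverse.
w_min_norm = np.linalg.pinv(X) @ y

# A different interpolating solution: add a direction from the null space of X,
# which changes the weights but not the predictions on the training set.
null_dir = rng.normal(size=d)
null_dir -= np.linalg.pinv(X) @ (X @ null_dir)   # remove the row-space component
w_other = w_min_norm + null_dir

# Gradient descent on 0.5 * ||X w - y||^2, started from zero.
step = 1.0 / np.linalg.norm(X, 2) ** 2           # 1 / (largest singular value)^2
w_gd = np.zeros(d)
for _ in range(5000):
    w_gd -= step * X.T @ (X @ w_gd - y)

print("both fit the training data:",
      np.allclose(X @ w_min_norm, y), np.allclose(X @ w_other, y))
print("norms (min-norm, other):", np.linalg.norm(w_min_norm), np.linalg.norm(w_other))
print("GD found the min-norm one:", np.allclose(w_gd, w_min_norm, atol=1e-8))
```

Both `w_min_norm` and `w_other` have perfect training accuracy, so the training loss alone cannot tell them apart. Which one you end up with is decided entirely by the optimizer and its starting point, which is the sense in which the inductive bias only gets to act once there are many interpolating models to choose from.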
I think this is strong evidence for the critical importance of implicit simplicity and speed priors in making modern ML work. However, such biases also produce strong incentives for mesa-optimization (since optimizers are simple, compressed policies) and pseudo-alignment (since simplicity and speed penalties will favor simpler, faster proxies). Furthermore, the arguments for the universal prior and minimal circuits being malign suggest that such strong simplicity and speed priors could also produce an incentive for deceptive alignment.
“Deep Double Descent”
Now we get to Preetum et al.’s new paper, “Deep Double Descent.” Here are just some of the things that Preetum et al. demonstrate in “Deep Double Descent”:
double descent occurs across a wide variety of different model classes, including ResNets, standard CNNs, and Transformers, as well as a wide variety of different tasks, including image classification and language translation,
double descent occurs not just as a function of model size, but also as a function of training time and dataset size, and
since double descent can happen as a function of dataset size, more data can lead to worse test performance!
Crazy stuff. Let’s try to walk through each of these results in detail and understand what’s happening.
First, double descent is a highly universal phenomenon in modern deep learning. Here is double descent happening for ResNet18 on CIFAR-10 and CIFAR-100:
And again for a Transformer model on German-to-English and English-to-French translation:
All of these graphs, however, are just showcasing the standard Belkin et al.-style double descent over model size (what Preetum et al. call “model-wise double descent”). What’s really interesting about “Deep Double Descent,” however, is that Preetum et al. also demonstrate that the same thing can happen for training time (“epoch-wise double descent”) and a similar thing for dataset size (“sample-wise non-monotonicity”).
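Model-wise double descent is easy to reproduce in a toy setting, which may help build intuition before moving on. The following is a minimal sketch, assuming a random ReLU feature model fit by minimum-norm least squares. That is a common stand-in for this kind of experiment, not the setup used in either paper, and the sizes, noise level, and names (`widths`, `relu_features`, `test_mse`) are my own choices. The interpolation threshold here is the width at which the number of features equals the number of training points.

```python
import numpy as np

rng = np.random.default_rng(0)
n_train, n_test, d, noise_std = 40, 1000, 10, 0.5   # assumed toy sizes
widths = [5, 10, 20, 40, 80, 400, 2000]             # 40 == n_train: interpolation threshold
n_trials = 20                                       # average to smooth out a noisy curve

def relu_features(X, W):
    """Random ReLU features: a fixed random first layer followed by ReLU."""
    return np.maximum(X @ W, 0.0)

test_mse = {p: 0.0 for p in widths}
for _ in range(n_trials):
    beta = rng.normal(size=d)
    beta /= np.linalg.norm(beta)                    # linear teacher with unit-variance signal
    X_tr = rng.normal(size=(n_train, d))
    X_te = rng.normal(size=(n_test, d))
    y_tr = X_tr @ beta + noise_std * rng.normal(size=n_train)
    y_te = X_te @ beta                              # evaluate against noiseless targets
    for p in widths:
        W = rng.normal(size=(d, p))
        # pinv gives ordinary least squares when p < n_train and the
        # minimum-norm interpolating solution when p >= n_train.
        coef = np.linalg.pinv(relu_features(X_tr, W)) @ y_tr
        pred = relu_features(X_te, W) @ coef
        test_mse[p] += np.mean((pred - y_te) ** 2) / n_trials

for p in widths:
    print(f"width={p:5d}  mean test MSE={test_mse[p]:.3f}")
```

In this sketch one would expect the test error to spike around width 40 and then come back down as the width keeps growing, which is the same qualitative shape as the model-wise curves above, just without any deep network involved.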
First, let’s look at epoch-wise double descent. Take a look at these graphs for ResNet18 on CIFAR-10:
There are a bunch of crazy things happening here which are worth pointing out. First, the obvious: epoch-wise double descent is definitely a thing. Holding model size fixed and training for longer exhibits the standard double descent behavior. Furthermore, the peak happens right at the interpolation threshold where you hit zero training error. Second, notice where you don’t get epoch-wise double descent: if your model is too small to ever hit the interpolation threshold (as was the case in ye olden days of ML), you never get epoch-wise double descent. Third, notice the log scale on the y axis: you have to train for quite a while to start seeing this phenomenon.
Finally, sample-wise non-monotonicity—Preetum et al. find a regime where increasing the amount of training data by four and a half times actually increases test loss (!):
What’s happening here is that more data increases the amount of model capacity/number of training epochs necessary to reach zero training error, which pushes out the interpolation threshold such that you can regress from the modern (interpolation) regime back into the classical (bias-variance tradeoff) regime, decreasing performance.
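The same effect shows up in a much simpler model, which I find helpful for seeing that it isn’t specific to deep networks. Below is a minimal sketch, assuming minimum-norm linear regression on Gaussian data with a fixed number of features `d`. The setup, sizes, and names (`train_sizes`, `test_mse`) are my own illustration rather than the experiment in the paper. Here the interpolation threshold sits at `n == d`, so adding training data moves you from the interpolation regime toward the threshold.

```python
import numpy as np

d, n_test, noise_std, n_trials = 50, 1000, 0.5, 20   # assumed toy sizes
train_sizes = [10, 25, 50, 100, 400]                 # 50 == d: interpolation threshold
rng = np.random.default_rng(0)

test_mse = {n: 0.0 for n in train_sizes}
for _ in range(n_trials):
    beta = rng.normal(size=d)
    beta /= np.linalg.norm(beta)                     # true weights, unit norm
    X_te = rng.normal(size=(n_test, d))
    for n in train_sizes:
        X_tr = rng.normal(size=(n, d))
        y_tr = X_tr @ beta + noise_std * rng.normal(size=n)
        # Minimum-norm fit for n < d, ordinary least squares for n >= d.
        beta_hat = np.linalg.pinv(X_tr) @ y_tr
        # Excess test error against the noiseless target.
        test_mse[n] += np.mean((X_te @ (beta_hat - beta)) ** 2) / n_trials

for n in train_sizes:
    print(f"n_train={n:4d}  mean test MSE={test_mse[n]:.3f}")
```

In this toy, going from 25 to 50 training points should make the test error worse, and it only recovers once there is comfortably more data than parameters. The model never changed; the extra data just dragged the interpolation threshold onto it.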
Another thing Preetum et al. point out that I think is worth talking about here is the impact of label noise. Preetum et al. find that increasing label noise significantly exaggerates the test error peak around the interpolation threshold. Why might this be the case? Well, if we think about the inductive biases story from earlier, greater label noise means that near the interpolation threshold SGD is forced to find the one model which fits all of the noise, which is likely to be pretty bad since it has to model a bunch of noise. After the interpolation threshold, however, SGD is able to pick between many models which fit the noise and select one that does so in the simplest way, such that you get good test performance.
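A rough way to see the “one model that has to fit all the noise” story is again minimum-norm linear regression, this time comparing different noise levels at the threshold and well inside the interpolation regime. This is a sketch under my own assumptions: additive Gaussian noise on regression targets rather than the flipped classification labels used in the paper, with `d`, the noise levels, and the helper `mean_test_mse` all invented for illustration.

```python
import numpy as np

d, n_test = 50, 1000   # assumed toy sizes

def mean_test_mse(n_train, noise_std, n_trials=20):
    """Average excess test error of the minimum-norm linear fit.

    Each trial reseeds the generator so that every noise level sees the same
    data and the same noise direction, only scaled differently.
    """
    total = 0.0
    for seed in range(n_trials):
        rng = np.random.default_rng(seed)
        beta = rng.normal(size=d)
        beta /= np.linalg.norm(beta)
        X_tr = rng.normal(size=(n_train, d))
        X_te = rng.normal(size=(n_test, d))
        eps = rng.normal(size=n_train)
        y_tr = X_tr @ beta + noise_std * eps
        beta_hat = np.linalg.pinv(X_tr) @ y_tr
        total += float(np.mean((X_te @ (beta_hat - beta)) ** 2))
    return total / n_trials

noise_levels = [0.0, 0.2, 0.5]
at_threshold = {s: mean_test_mse(d, s) for s in noise_levels}         # n_train == d
interpolating = {s: mean_test_mse(d // 5, s) for s in noise_levels}   # n_train << d

for s in noise_levels:
    print(f"noise={s:.1f}  at threshold={at_threshold[s]:.3f}  "
          f"interpolation regime={interpolating[s]:.3f}")
```

At the threshold there is exactly one weight vector that fits the training set, so all of the noise is forced into it and the error grows quickly with the noise level. With far more parameters than data points, the minimum-norm solution can absorb the same noise at much lower cost, so the error should barely move.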
Final comments
I’m quite excited about “Deep Double Descent,” but it still leaves what is in my opinion the most important question unanswered, which is: what exactly are the magical inductive biases of modern ML that make interpolation work so well?
One proposal I am aware of is the work of Keskar et al., who argue that SGD gets its good generalization properties from the fact that it finds “flat” as opposed to “sharp” minima. The basic insight is that SGD tends to jump out of minima without broad basins around them and only really settle into minima with large attractors, which tend to be the exact sort of minima that generalize. Keskar et al. use the following diagram to explain this phenomenon:
The more recent work of Dinh et al. in “Sharp Minima Can Generalize For Deep Nets,” however, calls the whole flat vs. sharp minima hypothesis into question, arguing that deep networks have really weird geometry that doesn’t necessarily work the way Keskar et al. want it to. (EDIT: Maybe not. See this comment for an explanation of why Dinh et al. doesn’t necessarily rule out the flat vs. sharp minima hypothesis.)
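The core of the Dinh et al. objection can be seen in a tiny example. Because ReLU is positively homogeneous, you can scale one layer’s weights up by some factor and the next layer’s down by the same factor without changing the function the network computes, yet the curvature of the loss at that point changes a lot. The sketch below is my own toy illustration, not code from either paper. It uses a single hidden unit, a single training example, and the largest Hessian eigenvalue as a stand-in for sharpness (Keskar et al. use a different, neighborhood-based measure).

```python
import numpy as np

def predict(w1, w2, x):
    """One-hidden-unit ReLU 'network': f(x) = w2 * relu(w1 * x)."""
    return w2 * np.maximum(w1 * x, 0.0)

def hessian_at_zero_loss(w1, w2):
    """Hessian of L(w1, w2) = (w1 * w2 - 1)^2 at a point where w1 * w2 == 1.

    L is the squared error of predict() on the single example (x=1, y=1),
    valid for w1 > 0. At zero loss the Hessian reduces to 2 * g g^T with
    g = (w2, w1), the gradient of the product w1 * w2.
    """
    g = np.array([w2, w1])
    return 2.0 * np.outer(g, g)

alpha = 10.0                                # rescaling factor (assumed)
w_a = (1.0, 1.0)                            # a zero-loss point
w_b = (alpha * w_a[0], w_a[1] / alpha)      # rescaled copy, still zero loss

xs = np.linspace(-2.0, 2.0, 9)
same_function = np.allclose(predict(*w_a, xs), predict(*w_b, xs))

sharpness_a = np.linalg.eigvalsh(hessian_at_zero_loss(*w_a)).max()   # 2 * (1 + 1) = 4
sharpness_b = np.linalg.eigvalsh(hessian_at_zero_loss(*w_b)).max()   # 2 * (0.01 + 100) = 200.02

print("same function:", same_function)
print("top Hessian eigenvalue:", sharpness_a, "vs", sharpness_b)
```

The two parameter settings make identical predictions everywhere, so they generalize identically, but by this measure one minimum is about fifty times sharper than the other. Any sharpness measure that depends on the parameterization in this way can’t, on its own, be what determines generalization.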
Another idea that might help here is Frankle and Carbin’s “Lottery Ticket Hypothesis,” which postulates that large neural networks work well because they are likely to contain random subnetworks at initialization (what they call “winning tickets”) which are already quite close to the final policy (at least in terms of being highly amenable to particularly effective training). My guess as to how double descent works if the Lottery Ticket Hypothesis is true is that in the interpolation regime SGD gets to just focus on the winning tickets and ignore the others (since it doesn’t have to use the full model capacity), whereas on the interpolation threshold SGD is forced to make use of the full network (to get the full model capacity), not just the winning tickets, which hurts generalization.
That’s just speculation on my part, however. We still don’t really understand the inductive biases of our models, despite the fact that, as double descent shows, inductive biases are the reason that modern ML (that is, the interpolation regime) works as well as it does. Furthermore, as I noted previously, inductive biases are highly relevant to the likelihood of possibly dangerous phenomena such as mesa-optimization and pseudo-alignment. Thus, it seems quite important to me to do further work in this area and really understand our models’ inductive biases, and I applaud Preetum et al. for their exciting work here.
EDIT: I have now written a follow-up to this post talking more about why I think double descent is important titled “Inductive biases stick around.”
This is perhaps one of the most striking fundamental discoveries of machine learning in the past 20 years, and Evan’s post is well-deserving of a nomination for explaining it to LW.
I’m no ML expert, but thanks to this post I feel like I have a basic grasp of some important ML theory. (It’s clearly written and has great graphs.) This is a big deal because this understanding of deep double descent has shaped my AI timelines to a noticeable degree.
I found this post interesting and helpful, and have used it as a mental hook on which to hang other things. Interpreting what’s going on with double descent, and what it implies, is tricky, and I’ll probably write a proper review at some point talking about that.